{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 5 Neural Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 5.1 Describe the drawbacks of using a linear function f(x) as a neuron's activation function."
   ]
  },
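  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "If the activation function is linear, the neural network is essentially just a multivariate linear regression model, however many layers it has, because a composition of linear maps is still linear.\n",
    "\n",
    "A linear model can only separate the samples with a hyperplane. For two-dimensional data $(x_1, x_2)$, for example, it draws a single straight line between the classes, so if the true decision boundary is nonlinear the classification will clearly be poor. The book mentions this on pp. 99-100."
   ]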
  },
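  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "As a worked check, here is a sketch for one hidden layer with the simplest linear activation $f(z) = z$. The weight matrices $W_1, W_2$ and offset vectors $c_1, c_2$ are notation assumed here for illustration, not the book's symbols.\n",
    "$$\n",
    "\\hat{y} = W_2 (W_1 x + c_1) + c_2 = (W_2 W_1) x + (W_2 c_1 + c_2) = W x + c\n",
    "$$\n",
    "The hidden layer folds into a single affine map $W x + c$, so stacking layers adds no expressive power."
   ]
  },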
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 5.2 Describe the connection between a neuron using the activation function of Fig. 5.2(b) and logistic regression."
   ]
  },
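  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "A single neuron with the sigmoid activation of Fig. 5.2(b) applies the sigmoid to a weighted sum of its inputs minus a threshold, which is the same functional form as logistic regression.\n",
    "\n",
    "In a network, each hidden neuron's output is therefore a logistic regression on the inputs, and each final output is a logistic regression on the outputs of the previous hidden layer.\n",
    "\n",
    "A single-hidden-layer network with one hidden neuron matches logistic regression exactly only when the output layer passes that neuron's output through unchanged (output weight 1 and no further activation).\n",
    "\n",
    "PS: this feels a bit like IVLS."
   ]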
  },
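  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Written side by side (a sketch: $w$ and $\\theta$ are the neuron's weights and threshold as in the book, while the intercept symbol $b$ for logistic regression is assumed here):\n",
    "$$\n",
    "\\text{neuron: } y = \\frac{1}{1 + e^{-(w^T x - \\theta)}}, \\qquad \\text{logistic regression: } y = \\frac{1}{1 + e^{-(w^T x + b)}}\n",
    "$$\n",
    "The two coincide when $b = -\\theta$."
   ]
  },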
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 5.3 For $v_{ih}$ in Fig. 5.7, derive the BP update formula (5.13)."
   ]
  },
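  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "By the chain rule, $v_{ih}$ influences $E_k$ only through the hidden-layer input $\\alpha_h$ and then the hidden-layer output $b_h$, so\n",
    "$$\n",
    "\\frac{\\partial E_k}{\\partial v_{ih}} \n",
    "= \\frac{\\partial E_k}{\\partial b_h} \\cdot \\frac{\\partial b_h}{\\partial \\alpha_h} \\cdot \\frac{\\partial \\alpha_h}{\\partial v_{ih}}\n",
    "$$\n",
    "Since $\\alpha_h = \\sum_{i=1}^{d} v_{ih} x_i$, clearly\n",
    "$$\n",
    "\\frac{\\partial \\alpha_h}{\\partial v_{ih}} = x_i\n",
    "$$\n",
    "Because the $h$-th hidden neuron is connected to all $l$ output neurons,\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\frac{\\partial E_k}{\\partial b_h}\n",
    "&= \\frac{\\partial \\frac{1}{2}\\sum_{j=1}^{l}(\\hat{y}_j^k - y_j^k)^2}{\\partial b_h} \\\\\n",
    "&= \\sum_{j=1}^{l}\\frac{\\partial \\frac{1}{2}(\\hat{y}_j^k - y_j^k)^2}{\\partial b_h} \\\\\n",
    "&= \\sum_{j=1}^{l}\\frac{\\partial E_k}{\\partial \\beta_j} \\cdot \\frac{\\partial\\beta_j}{\\partial b_h}\n",
    "= \\sum_{j=1}^{l}\\frac{\\partial E_k}{\\partial \\beta_j} w_{hj}\n",
    "\\end{aligned}\n",
    "$$\n",
    "using $\\beta_j = \\sum_{h=1}^{q} w_{hj} b_h$. From the earlier derivation for the output layer,\n",
    "$$\n",
    "-\\frac{\\partial E_k}{\\partial \\beta_j} = \\hat{y}_j^k(1 - \\hat{y}_j^k)(y_j^k - \\hat{y}_j^k) = g_j\n",
    "$$\n",
    "Moreover, since $f$ is the sigmoid function, $f'(x) = f(x)(1 - f(x))$ and\n",
    "$$\n",
    "\\frac{\\partial b_h}{\\partial \\alpha_h} = \\frac{\\partial f(\\alpha_h - \\gamma_h)}{\\partial \\alpha_h} = f(\\alpha_h - \\gamma_h)(1 - f(\\alpha_h - \\gamma_h)) = b_h(1 - b_h)\n",
    "$$\n",
    "Putting the three factors together,\n",
    "$$\n",
    "\\frac{\\partial E_k}{\\partial v_{ih}} = -b_h(1 - b_h)\\sum_{j=1}^{l} w_{hj} g_j \\, x_i = -e_h x_i, \\qquad e_h = b_h(1 - b_h)\\sum_{j=1}^{l} w_{hj} g_j\n",
    "$$\n",
    "so the gradient-descent update is\n",
    "$$\n",
    "\\Delta v_{ih} = -\\eta \\frac{\\partial E_k}{\\partial v_{ih}} = \\eta e_h x_i\n",
    "$$\n",
    "which is formula (5.13)."
   ]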
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 5.4 Describe how the learning rate in Eq. (5.6) affects neural network training."
   ]
  },
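  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "With gradient descent, a learning rate that is too large can make the iterates oscillate back and forth around an extremum (or even diverge), while a learning rate that is too small makes convergence slow."
   ]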
  },
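  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "A toy worked example makes this concrete. The one-dimensional objective $E(w) = w^2$ is an illustration chosen here, not taken from the book. Since $\\frac{dE}{dw} = 2w$, each gradient-descent step is\n",
    "$$\n",
    "w_{t+1} = w_t - \\eta \\cdot 2 w_t = (1 - 2\\eta)\\, w_t\n",
    "$$\n",
    "With $\\eta = 0.01$ the factor is $0.98$, so $w$ approaches the minimum at $0$ only slowly. With $\\eta = 0.9$ the factor is $-0.8$, so $w$ flips sign at every step while shrinking. With $\\eta = 1.1$ the factor is $-1.2$, so $|w|$ grows and the iteration diverges."
   ]
  },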
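  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 5.5 Implement the standard BP algorithm and the accumulated BP algorithm, use each to train a single-hidden-layer network on watermelon dataset 3.0, and compare them."
   ]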
  },
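  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "This is a binary classification problem with 17 samples. Each sample has 8 attributes (6 discrete ones plus density and sugar content), so the single-hidden-layer network has 8 input neurons, 2 output neurons, and a number of hidden neurons still to be chosen."
   ]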
  },
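  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "For reference, the two variants differ only in which error they descend. This sketch uses the book's symbols ($E_k$ for the error on sample $k$, $m$ for the number of samples, $\\eta$ for the learning rate); $\\omega$ is a placeholder introduced here for any single weight or threshold.\n",
    "$$\n",
    "\\text{standard BP: } \\Delta\\omega = -\\eta \\frac{\\partial E_k}{\\partial \\omega} \\quad \\text{(after each sample } k\\text{)}\n",
    "$$\n",
    "$$\n",
    "\\text{accumulated BP: } \\Delta\\omega = -\\eta \\frac{\\partial E}{\\partial \\omega}, \\qquad E = \\frac{1}{m}\\sum_{k=1}^{m} E_k \\quad \\text{(once per pass over the data)}\n",
    "$$"
   ]
  },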
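  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Stopping here for now. The basic standard BP and accumulated BP (ABP) algorithms are implemented, with two known shortcomings: (1) the number of hidden neurons is not determined automatically; (2) the objective function has no regularization term."
   ]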
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-02-04T12:12:49.159989Z",
     "start_time": "2020-02-04T12:12:48.672691Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-02-04T12:12:49.348496Z",
     "start_time": "2020-02-04T12:12:49.219643Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>编号</th>\n",
       "      <th>色泽</th>\n",
       "      <th>根蒂</th>\n",
       "      <th>敲声</th>\n",
       "      <th>纹理</th>\n",
       "      <th>脐部</th>\n",
       "      <th>触感</th>\n",
       "      <th>密度</th>\n",
       "      <th>含糖率</th>\n",
       "      <th>好瓜</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.697</td>\n",
       "      <td>0.460</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.774</td>\n",
       "      <td>0.376</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.634</td>\n",
       "      <td>0.264</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.608</td>\n",
       "      <td>0.318</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.556</td>\n",
       "      <td>0.215</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.403</td>\n",
       "      <td>0.237</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.481</td>\n",
       "      <td>0.149</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.437</td>\n",
       "      <td>0.211</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.666</td>\n",
       "      <td>0.091</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0.243</td>\n",
       "      <td>0.267</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.245</td>\n",
       "      <td>0.057</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0.343</td>\n",
       "      <td>0.099</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.639</td>\n",
       "      <td>0.161</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.657</td>\n",
       "      <td>0.198</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.360</td>\n",
       "      <td>0.370</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>16</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.593</td>\n",
       "      <td>0.042</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>17</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.719</td>\n",
       "      <td>0.103</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    编号  色泽  根蒂  敲声  纹理  脐部  触感     密度    含糖率  好瓜\n",
       "0    1   1   1   1   2   2   1  0.697  0.460   1\n",
       "1    2   0   1   0   2   2   1  0.774  0.376   1\n",
       "2    3   0   1   1   2   2   1  0.634  0.264   1\n",
       "3    4   1   1   0   2   2   1  0.608  0.318   1\n",
       "4    5   2   1   1   2   2   1  0.556  0.215   1\n",
       "5    6   1   0   1   2   0   0  0.403  0.237   1\n",
       "6    7   0   0   1   1   0   0  0.481  0.149   1\n",
       "7    8   0   0   1   2   0   1  0.437  0.211   1\n",
       "8    9   0   0   0   1   0   1  0.666  0.091   0\n",
       "9   10   1   2   2   2   1   0  0.243  0.267   0\n",
       "10  11   2   2   2   0   1   1  0.245  0.057   0\n",
       "11  12   2   1   1   0   1   0  0.343  0.099   0\n",
       "12  13   1   0   1   1   2   1  0.639  0.161   0\n",
       "13  14   2   0   0   1   2   1  0.657  0.198   0\n",
       "14  15   0   0   1   2   0   0  0.360  0.370   0\n",
       "15  16   2   1   1   0   1   1  0.593  0.042   0\n",
       "16  17   1   1   0   1   0   1  0.719  0.103   0"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "read_path = r'D:\\JWE\\Files\\Courses\\数据挖掘\\melon_data.csv' # raw string so the backslashes are not treated as escapes\n",
    "melon_data = pd.read_csv(read_path)\n",
    "# Column names come from the CSV: 编号 = ID, 好瓜 = good melon, 密度 = density, 含糖率 = sugar content\n",
    "data0 = melon_data.copy() # copy() is deep by default, so editing data0 leaves melon_data untouched\n",
    "data0.loc[data0['好瓜'] == '是',['好瓜']] = 1 # good melon\n",
    "data0.loc[data0['好瓜'] == '否',['好瓜']] = 0 # bad melon\n",
    "# Map each attribute name to the set of values it takes\n",
    "character_dict = {character:set(data0[character]) for character in data0.columns.drop(['编号','好瓜'])}\n",
    "temp_dict = character_dict.copy()\n",
    "del temp_dict['密度'],temp_dict['含糖率'] # continuous attributes are left as they are\n",
    "# Replace each discrete value with an integer code (the order follows set iteration, so it is arbitrary)\n",
    "for character in temp_dict:\n",
    "    for i,each in enumerate(character_dict[character]):\n",
    "        data0.loc[data0[character] == each,[character]] = i\n",
    "data0\n",
    "# character_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-02-04T12:12:49.542512Z",
     "start_time": "2020-02-04T12:12:49.512933Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Features (6 integer-coded discrete attributes, density, sugar content) and 0/1 labels as float arrays\n",
    "X = data0.drop(['编号','好瓜'],axis=1).values.astype(float)\n",
    "Y = data0['好瓜'].values.astype(float).reshape(-1,1)\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "class BP_result(object):\n",
    "    def __init__(self,predict,parameter):\n",
    "        self.parameter = parameter\n",
    "        self.predict = predict\n",
    "\n",
    "def init_params(d,q,l):\n",
    "    # Weights in [0,1) and thresholds in (-1,0]; thresholds are shared by all samples\n",
    "    return np.random.rand(q,d),-np.random.rand(1,q),np.random.rand(l,q),-np.random.rand(1,l)\n",
    "\n",
    "def forward(X,v,gamma,w,theta):\n",
    "    # Hidden-layer output b and network output y_hat for every row of X\n",
    "    b = sigmoid(np.dot(X,v.T) - gamma)\n",
    "    y_hat = sigmoid(np.dot(b,w.T) - theta)\n",
    "    return b,y_hat\n",
    "\n",
    "def mean_error(y_hat,T):\n",
    "    # Mean over samples of E_k = 1/2 * sum_j (y_hat_j - y_j)^2\n",
    "    return np.mean(0.5*np.sum(np.square(y_hat - T),axis=1))\n",
    "\n",
    "def BP(X,Y,d=8,q=10,l=2,eta=0.2,max_iters=10**(3),threshold=10**(-3)):\n",
    "    # Standard BP: update the parameters after every single sample\n",
    "    v_cur,gamma_cur,w_cur,theta_cur = init_params(d,q,l)\n",
    "    T = np.c_[Y,1 - Y] # targets for the two output neurons: [good, bad]\n",
    "    for _ in range(max_iters):\n",
    "        for x,t in zip(X,T):\n",
    "            x = x.reshape(1,-1)\n",
    "            t = t.reshape(1,-1)\n",
    "            b,y_hat = forward(x,v_cur,gamma_cur,w_cur,theta_cur)\n",
    "            g = y_hat*(1 - y_hat)*(t - y_hat) # output-layer gradient term g_j\n",
    "            e = b*(1 - b)*np.dot(g,w_cur) # hidden-layer gradient term e_h\n",
    "            v_cur += eta*np.dot(e.T,x)\n",
    "            gamma_cur += -eta*e\n",
    "            w_cur += eta*np.dot(g.T,b)\n",
    "            theta_cur += -eta*g\n",
    "        # Stop once the mean error over the whole training set falls below the threshold\n",
    "        _,predict = forward(X,v_cur,gamma_cur,w_cur,theta_cur)\n",
    "        if mean_error(predict,T) <= threshold:\n",
    "            break\n",
    "    _,predict = forward(X,v_cur,gamma_cur,w_cur,theta_cur)\n",
    "    return BP_result(predict=predict,parameter=(v_cur,gamma_cur,w_cur,theta_cur))\n",
    "\n",
    "def ABP(X,Y,d=8,q=10,l=2,eta=0.2,max_iters=10**(4),threshold=10**(-3)):\n",
    "    # Accumulated BP: one update per pass, using the gradient of the mean error over all samples\n",
    "    n = len(X)\n",
    "    v_cur,gamma_cur,w_cur,theta_cur = init_params(d,q,l)\n",
    "    T = np.c_[Y,1 - Y] # targets for the two output neurons: [good, bad]\n",
    "    for _ in range(max_iters):\n",
    "        b,Y_hat = forward(X,v_cur,gamma_cur,w_cur,theta_cur)\n",
    "        if mean_error(Y_hat,T) <= threshold:\n",
    "            break\n",
    "        g = Y_hat*(1 - Y_hat)*(T - Y_hat) # shape (n,l)\n",
    "        e = b*(1 - b)*np.dot(g,w_cur) # shape (n,q)\n",
    "        v_cur += eta*np.dot(e.T,X)/n\n",
    "        gamma_cur += -eta*np.mean(e,axis=0,keepdims=True)\n",
    "        w_cur += eta*np.dot(g.T,b)/n\n",
    "        theta_cur += -eta*np.mean(g,axis=0,keepdims=True)\n",
    "    _,Y_hat = forward(X,v_cur,gamma_cur,w_cur,theta_cur)\n",
    "    return BP_result(predict=Y_hat,parameter=(v_cur,gamma_cur,w_cur,theta_cur))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-02-04T12:12:51.061550Z",
     "start_time": "2020-02-04T12:12:49.741542Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "准确率：1.0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([[0.99145137, 0.00821363],\n",
       "       [0.96074113, 0.04116477],\n",
       "       [0.98380105, 0.01569879],\n",
       "       [0.97273749, 0.02855297],\n",
       "       [0.93940804, 0.05919846],\n",
       "       [0.9392943 , 0.06197716],\n",
       "       [0.87144086, 0.13409763],\n",
       "       [0.86994431, 0.1290403 ],\n",
       "       [0.07498202, 0.92826891],\n",
       "       [0.05330655, 0.94734122],\n",
       "       [0.02074829, 0.97983439],\n",
       "       [0.02110083, 0.97904266],\n",
       "       [0.05042665, 0.94615366],\n",
       "       [0.03829997, 0.96033018],\n",
       "       [0.14834663, 0.84762768],\n",
       "       [0.00281113, 0.99712972],\n",
       "       [0.02858429, 0.97280345]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_BP = BP(X=X,Y=Y,q=15)\n",
    "predict_BP = [1 if each[0]>=each[1] else 0 for each in result_BP.predict]\n",
    "accuracy = np.sum([data0['好瓜'].values.flatten() == predict_BP])/len(data0)\n",
    "print(f'准确率：{accuracy}')\n",
    "predict_BP\n",
    "result_BP.predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-02-04T12:12:52.064715Z",
     "start_time": "2020-02-04T12:12:51.274114Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "准确率：0.5294117647058824\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([[0.99274753, 0.99294701],\n",
       "       [0.99518231, 0.99538995],\n",
       "       [0.99491757, 0.99465035],\n",
       "       [0.99403788, 0.99372276],\n",
       "       [0.99243774, 0.991767  ],\n",
       "       [0.99215616, 0.9922073 ],\n",
       "       [0.99192101, 0.99202286],\n",
       "       [0.99296709, 0.99306767],\n",
       "       [0.00658237, 0.00690458],\n",
       "       [0.00843318, 0.00889915],\n",
       "       [0.00359811, 0.00397762],\n",
       "       [0.003472  , 0.00360604],\n",
       "       [0.00678775, 0.00696404],\n",
       "       [0.00691933, 0.00662707],\n",
       "       [0.00862378, 0.00930973],\n",
       "       [0.00423697, 0.00289196],\n",
       "       [0.00570112, 0.00500253]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_ABP = ABP(X=X,Y=Y,q=15)\n",
    "predict_ABP = [1 if each[0]>=each[1] else 0 for each in result_ABP.predict]\n",
    "accuracy = np.sum([data0['好瓜'].values.flatten() == predict_ABP])/len(data0)\n",
    "print(f'准确率：{accuracy}')\n",
    "predict_ABP\n",
    "result_ABP.predict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 5.6 试设计一个 BP 改进算法，能通过动态调整学习率显著提升收敛速度。编程实现该算法，并选择两个 UCI 数据集与标准 BP 算法进行实验比较。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "To be completed later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
